import os
import pandas as pd


def _split_dir_name(mode):
    """Return the dataset sub-directory that holds the video folders for *mode*."""
    # Substring tests so that filtered modes such as 'val_real' resolve to the
    # same split as their unfiltered counterpart ('train_real' already did).
    if 'train' in mode:
        return 'Train_files_crops'
    if 'val' in mode:
        return 'Dev_files_crops'
    return 'Test_files_crops'


def _video_label(video_name):
    """Return 0 (real) or 1 (fake) for a video directory name, validating its format."""
    fields = video_name.split('_')
    if len(fields) != 4:
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        raise ValueError(
            "unexpected video directory name {!r}: expected 4 '_'-separated "
            "fields".format(video_name))
    # A last field of 1 marks a real video; any other value is treated as fake.
    return 0 if int(fields[-1]) == 1 else 1


def generate_csv(root_dir, mode, out_dir='../data'):
    """Index every frame of one dataset split into ``<out_dir>/data_<mode>.csv``.

    The split directory is chosen from *mode*:

    * ``Train_files_crops`` if ``'train'`` occurs in *mode*;
    * ``Dev_files_crops`` if ``'val'`` occurs in *mode*;
    * ``Test_files_crops`` otherwise.

    Each sub-directory of the split is one video. Its name must have four
    ``'_'``-separated fields (presumably the Oulu-NPU naming scheme; confirm).
    A last field of ``1`` yields label 0 (real) and anything else label 1
    (fake). If *mode* contains ``'real'`` and/or ``'fake'``, only videos of the
    requested kind(s) are kept. Every entry inside a kept video directory
    becomes one CSV row ``(video, file, label)``.

    NOTE: an MSR-image variant of this script used ``*_files_MSR`` split
    directories and wrote ``data_<mode>_MSR.csv`` instead.

    Args:
        root_dir: Dataset root containing the ``*_files_crops`` directories.
        mode: Split selector, e.g. ``'train'``, ``'val_real'``,
            ``'train_fake'``, ``'test'``.
        out_dir: Directory the CSV is written to. Created if missing. Defaults
            to ``'../data'``, relative to the current working directory.

    Returns:
        The DataFrame that was written (columns ``video``, ``file``, ``label``).

    Raises:
        ValueError: If a video directory name does not have four fields or
            its last field is not an integer.
        FileNotFoundError: If the selected split directory does not exist.
    """
    videos_path = os.path.join(root_dir, _split_dir_name(mode))
    want_real = 'real' in mode
    want_fake = 'fake' in mode
    filtered = want_real or want_fake

    fold_data = []
    # Sorted listings make the CSV row order reproducible across runs/filesystems.
    for video_name in sorted(os.listdir(videos_path)):
        video_path = os.path.join(videos_path, video_name)
        if not os.path.isdir(video_path):
            continue  # ignore stray files (e.g. .DS_Store) next to video folders
        label = _video_label(video_name)
        if filtered and not ((want_real and label == 0) or (want_fake and label == 1)):
            continue
        fold_data.extend(
            [video_name, image, label] for image in sorted(os.listdir(video_path)))

    columns = ["video", "file", "label"]
    frame = pd.DataFrame(fold_data, columns=columns)
    os.makedirs(out_dir, exist_ok=True)
    frame.to_csv(os.path.join(out_dir, 'data_{}.csv'.format(mode)), index=False)
    return frame


def main():
    """Write the real-only and fake-only training CSVs for the hard-coded dataset root."""
    # Other roots this script has pointed at:
    #   '/data/heyan/Datasets/Face/oulu_256'          (RGB images)
    #   '/home/shaohua/data2/Datasets/Oulu_NPU_256'   (MSR images)
    dataset_root = '/home/shaohua/data2/Datasets/Face_Anti_Spoofing/Oulu_NPU'

    # The combined 'train', 'val' and 'test' CSVs can be produced by adding
    # those modes to this tuple; currently only the training subsets are built.
    for split_mode in ('train_real', 'train_fake'):
        generate_csv(dataset_root, split_mode)


if __name__ == "__main__":
    # Generate the CSVs only when run as a script, never on import.
    main()
